/**
 * \file dnn/src/cambricon/checksum/checksum_kernel_union4.mlu
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "checksum.mlu.h"
#include "cnsccl.h"
#include "mlu.h"

// Sizes of the shared reduction buffers below; the reductions in the kernel
// iterate over exactly this many cores / clusters.
#define CLUSTER_DIM 4
#define CORE_DIM 4
// Number of uint32_t elements staged in NRAM per copy.
#define STRIDE 1024

/**
 * \brief Position-weighted checksum of a uint32_t buffer.
 *
 * Every task accumulates, in uint32_t arithmetic, the terms
 * src[k] * (k + 1) for the elements k it owns. Ownership is chunked: task t
 * handles the STRIDE-element chunks starting at t * STRIDE and then every
 * taskDim * STRIDE elements after that. The per-core sums are reduced inside
 * each cluster by core 0, combined across clusters through cnscclGather, and
 * the total is written to dst[0] by core 0 of cluster 0.
 *
 * NOTE(review): the buffers are sized for CORE_DIM cores per cluster and
 * CLUSTER_DIM clusters; the kernel presumably must be launched with exactly
 * that geometry -- confirm against the host-side launch code.
 *
 * \param dst       output; only dst[0] is written
 * \param src       input elements, read through a GDRAM2NRAM copy
 * \param nr_elems  number of uint32_t elements in \p src; values <= 0 yield
 *                  no contribution from any task
 */
__mlu_entry__ void checksum_kernel_union4(uint32_t* dst, uint32_t* src,
                                          int nr_elems) {
    __nram__ uint32_t sum = 0;
    __nram__ uint32_t val[STRIDE];
    __mlu_shared__ uint32_t partial_sum_send[CORE_DIM];
    __mlu_shared__ uint32_t partial_sum_recv[CLUSTER_DIM];

    int task_offset = taskId * STRIDE;
    int global_stride = taskDim * STRIDE;

    while (task_offset < nr_elems) {
        // Work with the distance to the end of the buffer instead of
        // task_offset + STRIDE: with 0 <= task_offset < nr_elems the
        // subtraction cannot overflow, whereas the addition can when
        // nr_elems is close to INT_MAX.
        int remaining = nr_elems - task_offset;
        int copy_elems = remaining < STRIDE ? remaining : STRIDE;

        __memcpy(val, src + task_offset, copy_elems * sizeof(uint32_t),
                 GDRAM2NRAM);
        for (int i = 0; i < copy_elems; i++) {
            // Weight is the 1-based global element index.
            sum = sum + val[i] * (uint32_t)(task_offset + i + 1);
        }

        // Stop before advancing if the next chunk would start at or past
        // nr_elems; this keeps task_offset + global_stride from overflowing.
        if (remaining <= global_stride) {
            break;
        }
        task_offset += global_stride;
    }

    partial_sum_send[coreId] = sum;

    // All cores of the cluster must have published their sum before core 0
    // reads them.
    __sync_cluster();

    if (coreId == 0) {
        // Fold the cluster's per-core sums into slot 0.
        for (int i = 1; i < CORE_DIM; ++i) {
            partial_sum_send[0] += partial_sum_send[i];
        }
    }

    __sync_all();
    // Sends one element (the cluster total in slot 0) per cluster.
    // NOTE(review): presumably the root (last argument, 0) receives them in
    // partial_sum_recv ordered by cluster -- confirm against the CNSCCL docs.
    cnscclGather((void*)&partial_sum_send, (void*)&partial_sum_recv, 1,
                 cnscclInt, 0);

    __sync_all();

    if (clusterId == 0 && coreId == 0) {
        uint32_t res = 0;
        for (int i = 0; i < CLUSTER_DIM; ++i) {
            res += partial_sum_recv[i];
        }
        dst[0] = res;
    }
}
#undef CLUSTER_DIM
#undef CORE_DIM
#undef STRIDE

// vim: ft=cpp syntax=cpp.doxygen

